"""Split up step configurations [3d7e39f3ac92].

Revision ID: 3d7e39f3ac92
Revises: 857843db1bcf
Create Date: 2025-06-17 17:45:31.702617

"""

import json
import uuid

import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import mysql

from zenml.utils.time_utils import utc_now

# revision identifiers, used by Alembic.
revision = "3d7e39f3ac92"
down_revision = "857843db1bcf"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Upgrade database schema and/or data, creating a new revision.

    Steps performed, in order:

    1. Create the `step_configuration` table and add a (temporarily
       nullable) `step_count` column to `pipeline_deployment`.
    2. Parse the JSON stored in `pipeline_deployment.step_configurations`
       and turn every entry into one `step_configuration` row; the position
       of the entry inside the JSON mapping is stored in the `index` column.
    3. Write the number of entries of each deployment into `step_count`.
    4. Make `step_count` non-nullable and drop the old
       `step_configurations` column.
    """
    # --- Schema: new table -------------------------------------------------
    medium_text = sa.String(length=16777215).with_variant(
        mysql.MEDIUMTEXT, "mysql"
    )
    deployment_fk = sa.ForeignKeyConstraint(
        ["deployment_id"],
        ["pipeline_deployment.id"],
        name="fk_step_configuration_deployment_id_pipeline_deployment",
        ondelete="CASCADE",
    )
    unique_name_per_deployment = sa.UniqueConstraint(
        "deployment_id", "name", name="unique_step_name_for_deployment"
    )
    op.create_table(
        "step_configuration",
        sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        sa.Column("created", sa.DateTime(), nullable=False),
        sa.Column("updated", sa.DateTime(), nullable=False),
        sa.Column("index", sa.Integer(), nullable=False),
        sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column("config", medium_text, nullable=False),
        sa.Column(
            "deployment_id", sqlmodel.sql.sqltypes.GUID(), nullable=False
        ),
        deployment_fk,
        sa.PrimaryKeyConstraint("id"),
        unique_name_per_deployment,
    )

    # --- Schema: new column, nullable until it has been backfilled ---------
    with op.batch_alter_table("pipeline_deployment", schema=None) as batch_op:
        batch_op.add_column(
            sa.Column("step_count", sa.Integer(), nullable=True)
        )

    # --- Data: move the JSON entries into their own rows -------------------
    bind = op.get_bind()
    metadata = sa.MetaData()
    metadata.reflect(
        bind=bind, only=("pipeline_deployment", "step_configuration")
    )
    deployments = metadata.tables["pipeline_deployment"]
    step_configs = metadata.tables["step_configuration"]

    all_deployments = sa.select(
        deployments.c.id,
        deployments.c.step_configurations,
    )

    rows_to_insert = []
    step_count_params = []

    for owner_id, raw_json in bind.execute(all_deployments):
        # The stored value is iterated with `.items()` below, i.e. it is
        # expected to decode to a mapping of step name -> configuration.
        steps = json.loads(raw_json)
        step_count_params.append(
            {"id_": owner_id, "step_count": len(steps)}
        )

        for position, (name, config) in enumerate(steps.items()):
            timestamp = utc_now()
            rows_to_insert.append(
                {
                    # 32 hex characters, i.e. a UUID without dashes.
                    "id": uuid.uuid4().hex,
                    "created": timestamp,
                    "updated": timestamp,
                    "index": position,
                    "name": name,
                    "config": json.dumps(config),
                    "deployment_id": owner_id,
                }
            )

    op.bulk_insert(step_configs, rows=rows_to_insert)

    # Skip the update if there are no deployments: there would be no
    # parameter sets to execute it with.
    if step_count_params:
        set_step_count = sa.update(deployments).where(
            deployments.c.id == sa.bindparam("id_")
        )
        bind.execute(set_step_count, step_count_params)

    # --- Schema: finalize --------------------------------------------------
    with op.batch_alter_table("pipeline_deployment", schema=None) as batch_op:
        batch_op.alter_column(
            "step_count", existing_type=sa.Integer(), nullable=False
        )
        batch_op.drop_column("step_configurations")


def downgrade() -> None:
    """Downgrade database schema and/or data back to the previous revision.

    Reverses `upgrade()` without losing data:

    1. Re-add `pipeline_deployment.step_configurations` as a nullable
       column. Adding it as NOT NULL right away would fail for existing
       rows, which have no value for it yet.
    2. Rebuild the JSON mapping (step name -> configuration) of every
       deployment from the rows of `step_configuration`, ordered by their
       `index` column, and write it back. Deployments without any step get
       an empty mapping.
    3. Make `step_configurations` non-nullable, drop `step_count` and drop
       the `step_configuration` table.
    """
    # Same type that `upgrade()` uses for `step_configuration.config`; a
    # plain VARCHAR of this length is not accepted by MySQL.
    step_configurations_type = sa.String(length=16777215).with_variant(
        mysql.MEDIUMTEXT, "mysql"
    )

    with op.batch_alter_table("pipeline_deployment", schema=None) as batch_op:
        batch_op.add_column(
            sa.Column(
                "step_configurations",
                step_configurations_type,
                nullable=True,
            )
        )

    # Reflect only after the column was added so that it is part of the
    # reflected `pipeline_deployment` table.
    connection = op.get_bind()
    meta = sa.MetaData()
    meta.reflect(
        bind=connection, only=("pipeline_deployment", "step_configuration")
    )
    pipeline_deployment_table = sa.Table("pipeline_deployment", meta)
    step_configuration_table = sa.Table("step_configuration", meta)

    # Seed every deployment with an empty mapping so that deployments
    # without steps also end up with valid JSON instead of NULL.
    configs_by_deployment = {
        deployment_id: {}
        for (deployment_id,) in connection.execute(
            sa.select(pipeline_deployment_table.c.id)
        )
    }

    # `c["index"]` instead of `c.index` to avoid any clash between the
    # column name and attributes of the column collection itself.
    ordered_steps = sa.select(
        step_configuration_table.c.deployment_id,
        step_configuration_table.c.name,
        step_configuration_table.c.config,
    ).order_by(
        step_configuration_table.c.deployment_id,
        step_configuration_table.c["index"],
    )
    for deployment_id, step_name, config_json in connection.execute(
        ordered_steps
    ):
        # Dicts keep insertion order, so iterating by `index` restores the
        # original order of the steps inside the JSON mapping.
        configs_by_deployment.setdefault(deployment_id, {})[step_name] = (
            json.loads(config_json)
        )

    deployment_updates = [
        {"id_": deployment_id, "config_": json.dumps(step_configurations)}
        for deployment_id, step_configurations in configs_by_deployment.items()
    ]
    if deployment_updates:
        connection.execute(
            sa.update(pipeline_deployment_table)
            .where(pipeline_deployment_table.c.id == sa.bindparam("id_"))
            .values(step_configurations=sa.bindparam("config_")),
            deployment_updates,
        )

    with op.batch_alter_table("pipeline_deployment", schema=None) as batch_op:
        batch_op.alter_column(
            "step_configurations",
            existing_type=step_configurations_type,
            nullable=False,
        )
        batch_op.drop_column("step_count")

    op.drop_table("step_configuration")
